import os
import sys
import json
import argparse
import numpy as np
import math
import time
import random
import string
import h5py
from tqdm import tqdm
import webdataset as wds
import logging
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from accelerate import Accelerator
from torch.utils.data import DataLoader
sys.path.append("./mindeye2_src")

from mindeye2_src.utils import seed_everything
from mindeye2_src.models import PriorNetwork, BrainDiffusionPrior
# NOTE: adjust the import below to point at the directory that contains VoxelVAE
from models import VoxelVAE         

# ---------- 0. Configuration ----------
seed_everything(42)  # global seeding via MindEye2's helper, done before anything else

# Experiment selection
subj = 1
new_test = True
embedder_name = "ViT-L/14"

# Locations and hardware
data_path = 'dataset'
cache_dir = 'Cache'
device = torch.device('cuda:0')

# Sampling
repeat = 5            # prior samples drawn (and averaged) per test image
sampling_steps = 100  # timesteps passed to diffusion_prior.p_sample_loop; adjustable

# ---------- 1. Paths ----------
voxel_vae_path = f"./subj0{subj}_vae_eye/ckpt_300.pt"
diffusion_prior_path = f"./subj0{subj}_prior_vae_eye/ckpt_150.pt"
output_dir = f"./subj0{subj}_prior_vae_eye"
os.makedirs(output_dir, exist_ok=True)

# ---------- 2. Data ----------
# 2.1 fMRI betas: read the full matrix into memory, then close the file.
with h5py.File(f'{data_path}/betas_all_subj0{subj}_fp32_renorm.hdf5', 'r') as betas_h5:
    voxels = torch.tensor(betas_h5['betas'][:]).float()
num_voxels = voxels.size(-1)  # voxel count = length of the last axis

# 2.2 CLIP image embeddings. The handle stays open because rows are read
# lazily, one image at a time, when the test set is assembled.
clip_emb_file = h5py.File(f'{data_path}/clip_embeddings.hdf5', 'r')
clip_embeddings = clip_emb_file['embeddings']  # per the original note: (N, 257, 768)
 
# 2.3 WebDataset
def my_split_by_node(urls):
    """Identity node splitter: hand every URL to this node unchanged.

    Passed to ``wds.WebDataset`` as ``nodesplitter`` so that no per-node
    sharding of the URL list takes place.
    """
    return urls

test_url = f"{data_path}/wds/subj0{subj}/new_test/0.tar"

# Pipeline: seeded shuffle -> torch decoding -> rename .npy keys -> 4-tuple.
test_data = (
    wds.WebDataset(test_url, resampled=False, nodesplitter=my_split_by_node)
    .shuffle(750, initial=1500, rng=random.Random(42))
    .decode("torch")
    .rename(
        behav="behav.npy",
        past_behav="past_behav.npy",
        future_behav="future_behav.npy",
        olds_behav="olds_behav.npy",
    )
    .to_tuple("behav", "past_behav", "future_behav", "olds_behav")
)

# A single large batch: the inference stage only consumes the first batch,
# so batch_size is expected to cover the whole test split.
test_dl = DataLoader(
    test_data,
    batch_size=3000,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
    num_workers=4,
    prefetch_factor=2,
)

# ---------- 3. Models ----------
clip_seq_dim = 257  # tokens per CLIP image embedding
clip_emb_dim = 768  # width of each token

# The VAE checkpoint also holds the fovea-attention weights, so read it once.
# map_location="cpu" makes loading independent of the GPU the checkpoint was
# saved from; load_state_dict then copies the weights into the modules, which
# have already been moved to `device`.
vae_ckpt = torch.load(voxel_vae_path, map_location="cpu")

# 3.1 Voxel VAE (its decoder maps sampled CLIP-space latents to voxels below)
vae = VoxelVAE(
    num_voxels=num_voxels,
    token_dim=clip_emb_dim,
    num_tokens=clip_seq_dim,
    hidden_dim=256,
    n_blocks=2,
    drop=0.15,
).to(device)
vae.load_state_dict(vae_ckpt['model'])
vae.eval()

# 3.2 Diffusion prior; hyperparameters must match the ones the checkpoint was trained with
prior = PriorNetwork(
    dim=clip_emb_dim,
    depth=6,
    dim_head=48,
    heads=clip_emb_dim // 52,  # = 14
    causal=False,
    num_tokens=clip_seq_dim,
    learned_query_mode="pos_emb",
)
diffusion_prior = BrainDiffusionPrior(
    net=prior,
    image_embed_dim=clip_emb_dim,
    condition_on_text_encodings=False,
    timesteps=100,
    cond_drop_prob=0.2,
    image_embed_scale=None,
).to(device)
diffusion_prior.load_state_dict(torch.load(diffusion_prior_path, map_location="cpu")['model'])
diffusion_prior.eval()

# 3.3 Central-fovea attention, applied to CLIP embeddings before they condition the prior
from models import CentralFoveaAttention
attn = CentralFoveaAttention(embed_dim=clip_emb_dim, grid_size=16).to(device)
attn.load_state_dict(vae_ckpt['attn'])
attn.eval()
del vae_ckpt  # release the CPU copy of the checkpoint

# ---------- 4. Inference ----------
mse = nn.MSELoss()

with torch.no_grad():
    # Only the first batch is used; test_dl's batch_size is expected to cover
    # the entire test split (any later batches are ignored).
    first_batch = next(iter(test_dl), None)
    if first_batch is None:
        raise RuntimeError(f"test loader yielded no samples from {test_url}")
    behav = first_batch[0]

    # behav[:, 0, 5] selects rows of `voxels`; behav[:, 0, 0] is the image index
    # used to look up the CLIP embedding.
    voxel = voxels[behav[:, 0, 5].cpu().long()]
    image_idx = behav[:, 0, 0].cpu().long()

    # Group trials by image (torch.unique returns sorted ids) and pad every
    # image to exactly 3 trials. Collect into lists and stack once instead of
    # re-allocating with vstack on every image (quadratic).
    clip_list, voxel_list = [], []
    for im in torch.unique(image_idx):
        locs = torch.where(image_idx == im)[0]
        if len(locs) == 1:
            locs = locs.repeat(3)
        elif len(locs) == 2:
            # NOTE(review): this yields [a, b, a], so the mean below weights
            # trial a twice. Kept as-is; presumably mirrors the upstream
            # MindEye2 test-set code — confirm this is intended.
            locs = locs.repeat(2)[:3]
        if len(locs) != 3:
            raise ValueError(
                f"image {int(im)} has {len(locs)} test trials; expected 1 to 3")
        clip_list.append(torch.as_tensor(clip_embeddings[int(im)], dtype=torch.float32))
        voxel_list.append(voxel[locs])
    test_clip_emb = torch.stack(clip_list)  # one embedding per unique image
    test_voxel = torch.stack(voxel_list)    # (num_images, 3, ...) trial groups

    # Ground truth = mean over the 3 (possibly padded) trials of each image.
    test_voxel_mean = torch.mean(test_voxel, dim=1)
    torch.save(test_voxel_mean.cpu(),
               os.path.join(output_dir, "test_voxel_mean.pt"))
    num_test = test_voxel_mean.size(0)
    predicted_fmri = torch.zeros((num_test, num_voxels), device=device)
    mse_list = []

    for i in tqdm(range(num_test)):
        # Fovea attention transforms the CLIP embedding before it is used as conditioning.
        image_rep_i = attn(test_clip_emb[i].unsqueeze(0).to(device))
        voxel_i = test_voxel_mean[i].unsqueeze(0).to(device)

        # Draw `repeat` prior samples, decode each to voxels, and average them.
        pred_i_repeat = torch.zeros((repeat, 1, num_voxels), device=device)
        for repe in range(repeat):
            pred_rep_i_repe = diffusion_prior.p_sample_loop(
                [1, clip_seq_dim, clip_emb_dim],
                text_cond=dict(text_embed=image_rep_i),
                cond_scale=1.,
                timesteps=sampling_steps,
            )
            pred_i_repeat[repe] = vae.decode(pred_rep_i_repe)
        pred_i = torch.mean(pred_i_repeat, dim=0)
        predicted_fmri[i] = pred_i
        mse_list.append(mse(pred_i, voxel_i).item())

# ---------- 5. Report & save ----------
print(f"avg MSE = {np.mean(mse_list):.6f}")

# Tag the output with the prior checkpoint actually loaded ("ckpt_150.pt" -> "150")
# instead of a hard-coded epoch, so changing diffusion_prior_path cannot
# mislabel results or overwrite the file written for another checkpoint.
ckpt_name = os.path.splitext(os.path.basename(diffusion_prior_path))[0]
ckpt_epoch = ckpt_name[len("ckpt_"):] if ckpt_name.startswith("ckpt_") else ckpt_name
torch.save(predicted_fmri.cpu(),
           os.path.join(output_dir, f"predicted_fmri_ep{ckpt_epoch}_step{sampling_steps}_repeat{repeat}.pt"))